{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "import torch\n",
    "from PIL import Image\n",
    "import torch.nn.functional as F\n",
    "import matplotlib.pyplot as plt\n",
    "from utils.utils_correspondence import resize\n",
    "from model_utils.extractor_sd import load_model, process_features_and_mask\n",
    "from model_utils.extractor_dino import ViTExtractor\n",
    "from model_utils.projection_network import AggregationNetwork\n",
    "from preprocess_map import set_seed\n",
    "\n",
    "# Fixed seed so repeated runs of the notebook behave the same.\n",
    "set_seed(42)\n",
    "# Side length of the feature grid: the SD input is num_patches*16 px, the DINOv2 input num_patches*14 px.\n",
    "num_patches = 60\n",
    "# The extractors stay None until the model-loading cell below is run;\n",
    "# get_processed_features only needs them when no pre-extracted features exist.\n",
    "sd_model = sd_aug = extractor_vit = None\n",
    "aggre_net = AggregationNetwork(feature_dims=[640,1280,1280,768], projection_dim=768, device='cuda')\n",
    "# map_location remaps the checkpoint onto the visible GPU even if it was saved from another device index.\n",
    "aggre_net.load_pretrained_weights(torch.load('results_spair/best_856.PTH', map_location='cuda'))\n",
    "\n",
    "def get_processed_features(sd_model, sd_aug, aggre_net, extractor_vit, num_patches, img=None, img_path=None):\n",
    "    \"\"\"Return the aggregated, L2-normalized descriptor map of a single image.\n",
    "\n",
    "    Stable Diffusion and DINOv2 features are loaded from the pre-extracted\n",
    "    `*_sd.pt` / `*_dino.pt` files derived from `img_path` when those files\n",
    "    exist; otherwise they are extracted from the image on the fly.\n",
    "\n",
    "    Args:\n",
    "        sd_model, sd_aug: objects returned by `load_model`; only used when the\n",
    "            SD features are not cached.\n",
    "        aggre_net: network applied to the concatenated SD and DINOv2 features.\n",
    "        extractor_vit: DINOv2 extractor; only used when the DINOv2 features\n",
    "            are not cached.\n",
    "        num_patches: side length of the square feature grid.\n",
    "        img: optional PIL image; when given it is used instead of opening\n",
    "            `img_path`.\n",
    "        img_path: optional path of the image; used to locate cached features\n",
    "            and to open the image when `img` is None.\n",
    "\n",
    "    Returns:\n",
    "        Descriptor tensor divided by its L2 norm along dim 1\n",
    "        (1, 768, 60, 60 with the settings of this notebook).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if both `img` and `img_path` are None.\n",
    "        RuntimeError: if features have to be extracted but the required\n",
    "            extractor has not been loaded.\n",
    "    \"\"\"\n",
    "    if img is None and img_path is None:\n",
    "        raise ValueError('Either `img` or `img_path` must be provided.')\n",
    "\n",
    "    # cache locations stay None when no path is given, so the cache lookups below are skipped\n",
    "    sd_path = dino_path = None\n",
    "    if img_path is not None:\n",
    "        feature_base = img_path.replace('JPEGImages', 'features').replace('.jpg', '')\n",
    "        sd_path = f\"{feature_base}_sd.pt\"\n",
    "        dino_path = f\"{feature_base}_dino.pt\"\n",
    "\n",
    "    # extract stable diffusion features\n",
    "    if sd_path is not None and os.path.exists(sd_path):\n",
    "        # map_location puts every cached tensor on the GPU, whatever device it was saved from\n",
    "        features_sd = torch.load(sd_path, map_location='cuda')\n",
    "    else:\n",
    "        if sd_model is None:\n",
    "            raise RuntimeError('No cached SD features found and sd_model is None; run the model-loading cell first.')\n",
    "        if img is None: img = Image.open(img_path).convert('RGB')\n",
    "        img_sd_input = resize(img, target_res=num_patches*16, resize=True, to_pil=True)\n",
    "        features_sd = process_features_and_mask(sd_model, sd_aug, img_sd_input, mask=False, raw=True)\n",
    "        del features_sd['s2']  # 's2' is not part of the aggregated descriptor\n",
    "\n",
    "    # extract dinov2 features\n",
    "    if dino_path is not None and os.path.exists(dino_path):\n",
    "        # load onto the GPU like the SD features so the concatenation below never mixes devices\n",
    "        features_dino = torch.load(dino_path, map_location='cuda')\n",
    "    else:\n",
    "        if extractor_vit is None:\n",
    "            raise RuntimeError('No cached DINOv2 features found and extractor_vit is None; run the model-loading cell first.')\n",
    "        if img is None: img = Image.open(img_path).convert('RGB')\n",
    "        img_dino_input = resize(img, target_res=num_patches*14, resize=True, to_pil=True)\n",
    "        img_batch = extractor_vit.preprocess_pil(img_dino_input)\n",
    "        features_dino = extractor_vit.extract_descriptors(img_batch.cuda(), layer=11, facet='token').permute(0, 1, 3, 2).reshape(1, -1, num_patches, num_patches)\n",
    "\n",
    "    # concatenate along channels: SD 's3' as-is, 's4'/'s5' resized to the feature grid, then DINOv2\n",
    "    desc_gathered = torch.cat([\n",
    "            features_sd['s3'],\n",
    "            F.interpolate(features_sd['s4'], size=(num_patches, num_patches), mode='bilinear', align_corners=False),\n",
    "            F.interpolate(features_sd['s5'], size=(num_patches, num_patches), mode='bilinear', align_corners=False),\n",
    "            features_dino\n",
    "        ], dim=1)\n",
    "\n",
    "    desc = aggre_net(desc_gathered) # 1, 768, 60, 60\n",
    "    # normalize the descriptors\n",
    "    norms_desc = torch.linalg.norm(desc, dim=1, keepdim=True)\n",
    "    desc = desc / (norms_desc + 1e-8)\n",
    "    return desc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Pretrained SD and DINOv2 Models\n",
    "The next cell may take ~4 minutes to run.\n",
    "You can **skip it** if you only want to visualize the post-processed features of *dataset images* whose features you have *already pre-extracted*."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loading both extractors may take ~4 minutes.\n",
    "# Skip this cell if you only want to visualize dataset images whose features are already pre-extracted.\n",
    "\n",
    "# Stable Diffusion v1-5 feature extractor; its input resolution is 16 px per feature patch.\n",
    "sd_model, sd_aug = load_model(\n",
    "    diffusion_ver='v1-5',\n",
    "    image_size=num_patches*16,\n",
    "    num_timesteps=50,\n",
    "    block_indices=[2, 5, 8, 11],\n",
    ")\n",
    "# DINOv2 ('dinov2_vitb14') feature extractor.\n",
    "extractor_vit = ViTExtractor(\n",
    "    'dinov2_vitb14',\n",
    "    stride=14,\n",
    "    device='cuda',\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Correspondence Demo for Dataset Pairs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img_size = 480\n",
    "img1_path = 'data/SPair-71k/JPEGImages/dog/2010_000899.jpg' # path to the source image\n",
    "img2_path = 'data/SPair-71k/JPEGImages/dog/2011_002398.jpg' # path to the target image\n",
    "\n",
    "# open both images as RGB and resize them to the demo resolution\n",
    "img1, img2 = [\n",
    "    resize(Image.open(path).convert('RGB'), target_res=img_size, resize=True, to_pil=True)\n",
    "    for path in (img1_path, img2_path)\n",
    "]\n",
    "\n",
    "# show source and target side by side, without axes\n",
    "fig, ax = plt.subplots(1, 2, figsize=(10, 5))\n",
    "for panel, picture, caption in zip(ax, (img1, img2), ('source image', 'target image')):\n",
    "    panel.axis('off')\n",
    "    panel.imshow(picture)\n",
    "    panel.set_title(caption)\n",
    "plt.show()\n",
    "\n",
    "# descriptor maps of both images; passing img_path lets pre-extracted features be reused when they exist\n",
    "feat1, feat2 = [\n",
    "    get_processed_features(sd_model, sd_aug, aggre_net, extractor_vit, num_patches, img=picture, img_path=path)\n",
    "    for picture, path in ((img1, img1_path), (img2, img2_path))\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib widget\n",
    "from utils.utils_visualization_demo import Demo\n",
    "\n",
    "# batch the two descriptor maps along dim 0 (source first, target second) for the viewer\n",
    "demo = Demo(\n",
    "    [img1, img2],\n",
    "    torch.cat((feat1, feat2), dim=0),\n",
    "    img_size,\n",
    ")\n",
    "demo.plot_img_pairs(fig_size=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Correspondence Demo for Your Own Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img_size = 480\n",
    "img1_path = 'data/images/unicorn.jpg' # path to the source image\n",
    "img2_path = 'data/images/antelope.jpg' # path to the target image\n",
    "\n",
    "# open both images as RGB and resize them to the demo resolution\n",
    "img1, img2 = [\n",
    "    resize(Image.open(path).convert('RGB'), target_res=img_size, resize=True, to_pil=True)\n",
    "    for path in (img1_path, img2_path)\n",
    "]\n",
    "\n",
    "# show source and target side by side, without axes\n",
    "fig, ax = plt.subplots(1, 2, figsize=(10, 5))\n",
    "for panel, picture, caption in zip(ax, (img1, img2), ('source image', 'target image')):\n",
    "    panel.axis('off')\n",
    "    panel.imshow(picture)\n",
    "    panel.set_title(caption)\n",
    "plt.show()\n",
    "\n",
    "# no img_path is passed here, so the features are always extracted from the images themselves\n",
    "feat1, feat2 = [\n",
    "    get_processed_features(sd_model, sd_aug, aggre_net, extractor_vit, num_patches, img=picture)\n",
    "    for picture in (img1, img2)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib widget\n",
    "from utils.utils_visualization_demo import Demo\n",
    "\n",
    "# batch the two descriptor maps along dim 0 (source first, target second) for the viewer\n",
    "demo = Demo(\n",
    "    [img1, img2],\n",
    "    torch.cat((feat1, feat2), dim=0),\n",
    "    img_size,\n",
    ")\n",
    "demo.plot_img_pairs(fig_size=5)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sd_dino",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
